#!/usr/bin/env python3

import bz2
import configparser
import glob
import gzip
import os
import shutil
import sqlite3
import sys
import urllib.request
import xml.etree.ElementTree as ET

import rpm

# Path of the configuration file, obtained by letting rpm expand the
# %{_sysconfdir} macro in front of '/rpmdepsearch/rpmdepsearch.conf'.
rpmdepsearch_conf=rpm.expandMacro('%{_sysconfdir}'+'/rpmdepsearch/rpmdepsearch.conf')
# Cache directory for the downloaded primary metadata files, obtained by
# letting rpm expand %{_localstatedir}. The value ends with '/'.
primary_dir=rpm.expandMacro('%{_localstatedir}'+'/cache/rpmdepsearch/')

def check_primary_dir():
    """
    Make sure the cache directory is present.

    Returns True when primary_dir was already there. Otherwise the
    directory (including missing parents) is created and False is returned.
    """
    already_present = os.path.exists(primary_dir)
    if not already_present:
        os.makedirs(primary_dir)
    return already_present

def download_db():
    """
    Download the primary metadata of every configured repository.

    Reads the newline-separated ``baseurl`` value from the ``[repos]``
    section of rpmdepsearch_conf, fetches ``repodata/repomd.xml`` from each
    repo and stores the file referenced by its ``primary_db`` entry (falling
    back to ``primary``) in primary_dir. Finally get_sqlite_file() is called
    to unpack the downloaded archives.

    A repo that fails is reported and skipped; the remaining repos are still
    processed. Returns None.
    """
    try:
        # load repo_urls
        config = configparser.ConfigParser()
        config.read(rpmdepsearch_conf)
        raw_urls = config['repos']['baseurl']
    except (KeyError, ValueError, configparser.Error):
        print("No repo url,please add url to {}".format(rpmdepsearch_conf))
        return

    # expandvars() returns a new string, so the result has to be kept.
    raw_urls = os.path.expandvars(raw_urls)
    # Drop blank lines: splitting never yields an empty list on its own.
    repo_urls = [line.strip() for line in raw_urls.splitlines() if line.strip()]
    if not repo_urls:
        print("No repo url,please add url to {}".format(rpmdepsearch_conf))
        return

    try:
        check_primary_dir()
    except OSError as err:
        print("cannot create cache dir {}".format(primary_dir))
        print(err)
        return

    xmlns = "{http://linux.duke.edu/metadata/repo}"
    for repo_url in repo_urls:
        # Accept base urls given with or without a trailing 'repodata/'.
        # str.strip(chars) removes a *set of characters*, not a suffix, so an
        # explicit suffix removal is required to keep the url intact.
        repo_url = repo_url.rstrip('/').removesuffix('/repodata').rstrip('/')
        repo_md_url = repo_url + '/repodata/repomd.xml'
        print("downloading from {}".format(repo_md_url))
        # Tracks the url being fetched so the failure message is always valid.
        current_url = repo_md_url
        try:
            with urllib.request.urlopen(repo_md_url) as response:
                root = ET.fromstring(response.read())
            # search for data type="primary_db", fall back to the xml form
            primarys = root.findall(xmlns + "data[@type='primary_db']")
            if not primarys:
                primarys = root.findall(xmlns + "data[@type='primary']")
            if not primarys:
                print("no primary metadata listed in {}".format(repo_md_url))
            for entry in primarys:
                location = entry.find(xmlns + 'location')
                if location is None:
                    continue
                href = location.get('href')
                if not href:
                    continue
                primary_url = "{}/{}".format(repo_url, href.lstrip('/'))
                current_url = primary_url
                # Only the file name is used locally; this also keeps a
                # hostile href from escaping the cache directory.
                target = os.path.join(primary_dir, os.path.basename(href))
                urllib.request.urlretrieve(primary_url, target)
                print("done")
        except Exception as err:
            # Best effort per repo: one broken repo must not abort the rest.
            print("download from {} failed".format(current_url))
            print(err)
    get_sqlite_file()
    print('make primary_db complete')

def get_sqlite_file(directory=None):
    """
    Decompress every ``*.sqlite.bz2`` file found in the cache directory.

    Each archive is unpacked next to itself (same name without ``.bz2``)
    and removed afterwards. An archive that cannot be decompressed is
    reported, left in place, and its partial output is deleted.

    directory: directory to scan; defaults to primary_dir when omitted.
    """
    if directory is None:
        directory = primary_dir
    # escape() keeps glob metacharacters in the directory name literal.
    pattern = os.path.join(glob.escape(directory), '*.sqlite.bz2')
    for archive in sorted(glob.glob(pattern)):
        target = archive[:-len('.bz2')]
        try:
            with bz2.open(archive, 'rb') as src, open(target, 'wb') as dst:
                shutil.copyfileobj(src, dst)
        except (OSError, EOFError) as e:
            # Corrupt data raises OSError, a truncated stream EOFError.
            print("except in bzip2 " + str(e))
            if os.path.exists(target):
                try:
                    os.remove(target)
                except OSError:
                    pass
            continue
        os.remove(archive)

def clear_db():
    """
    Clear the database: delete the cached primary metadata files
    ('-primary.sqlite' and '-primary.xml.gz') from the cache directory.
    """
    check_primary_dir()
    cached_suffixes = ('-primary.sqlite', '-primary.xml.gz')
    for entry in os.listdir(primary_dir):
        if entry.endswith(cached_suffixes):
            os.remove(os.path.join(primary_dir, entry))
    print("database cleared")



def query_all_required(_package:str) -> list[str]:
    """
    Print and return the names of all packages that require _package.

    Every '-primary.sqlite' and '-primary.xml.gz' file in the cache
    directory is searched. A package found in several files is reported
    once, in first-seen order.

    Returns the list of package names; the list is empty when nothing
    requires _package or when the cache directory did not exist yet.
    """
    if not check_primary_dir():
        # check_primary_dir() has just created an empty cache directory, so
        # there is nothing to search until 'setup' has downloaded metadata.
        print("primary_dir does not exist,run setup")
        return []
    found = []
    for file in sorted(os.listdir(primary_dir)):
        file_path = os.path.join(primary_dir, file)
        if file.endswith('-primary.sqlite'):
            found.extend(query_one_required_sqlite(_package, file_path))
        elif file.endswith('-primary.xml.gz'):
            found.extend(query_one_required_xml(_package, file_path))
    # Several repos may carry the same package; keep one entry per name.
    results = list(dict.fromkeys(found))
    if results:
        print("%s is required by:" % _package)
        for result in results:
            print("\t" + result)
    else:
        print("No package requires %s" % _package)
    print()
    return results

def query_one_required_sqlite(_package:str,_file_path:str) -> list[str]:
    """
    Return the distinct names of packages requiring _package in one
    primary sqlite database.

    _package:   name to look up in the ``requires`` table.
    _file_path: path of the sqlite database; it is expected to provide the
                ``packages`` and ``requires`` tables used by the query.

    Raises sqlite3.Error when the file cannot be queried. The connection
    is closed in every case.
    """
    conn = sqlite3.connect(_file_path)
    try:
        cursor = conn.execute(
            "SELECT DISTINCT p.name FROM requires r JOIN packages p ON r.pkgKey = p.pkgKey WHERE r.name = ?",
            (_package,))
        # fetchall() always returns a list (possibly empty), never None.
        return [row[0] for row in cursor.fetchall()]
    finally:
        # Closing the connection also releases its cursors.
        conn.close()

def query_one_required_xml(required_package:str,_file_path:str) -> list[str]:
    """
    Return the names of packages requiring required_package in one gzipped
    primary.xml file.

    required_package: value compared against the ``name`` attribute of the
                      ``rpm:entry`` elements below ``rpm:requires``.
    _file_path:       path of the gzip-compressed primary xml file.

    Returns the matching package names without duplicates, sorted.
    """
    xmlns = {'rpm': 'http://linux.duke.edu/metadata/rpm'}
    xmlns_common = "{http://linux.duke.edu/metadata/common}"
    # The context manager closes the gzip handle once parsing is done.
    with gzip.open(_file_path, 'rb') as primary:
        primary_tree = ET.parse(primary)
    result = set()
    for package in primary_tree.findall(xmlns_common + "package"):
        package_name = package.findtext(xmlns_common + 'name')
        if not package_name:
            continue
        # Compare the attribute in Python instead of formatting the name into
        # the XPath predicate, which breaks on names containing a quote.
        entries = package.iterfind('.//rpm:requires//rpm:entry', xmlns)
        if any(entry.get('name') == required_package for entry in entries):
            result.add(package_name)
    return sorted(result)

def help():
    """Print the help text and return 0 as the exit status."""
    message = "This is help"
    print(message)
    return 0
def main():
    """
    Command-line entry point.

    Without arguments the help text is shown. 'setup' as first argument
    clears the cache and downloads the metadata again. Any other arguments
    are each treated as a name to look up. Returns the exit status.
    """
    argv = sys.argv
    if len(argv) == 1:
        return help()
    command = argv[1]
    if command == 'setup':
        clear_db()
        download_db()
        return 0
    for package in argv[1:]:
        query_all_required(package)
    return 0



if __name__=='__main__':
    # Hand main()'s return value to the shell as the process exit status
    # instead of discarding it.
    sys.exit(main())